// Copyright 2017 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>

#include <fcntl.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

#include <errno.h>
#include <stdint.h>

#include <zircon/boot/netboot.h>

#include "bootserver.h"

#define DEFAULT_US_BETWEEN_PACKETS 20

static uint32_t cookie = 1;
static const int MAX_READ_RETRIES = 10;
static const int MAX_SEND_RETRIES = 10000;

static int io_rcv(int s, nbmsg* msg, nbmsg* ack) {
    for (int i = 0; i < MAX_READ_RETRIES; i++) {
        bool retry_allowed = i + 1 < MAX_READ_RETRIES;

        ssize_t r = read(s, ack, 2048);
        if (r < 0) {
            if (retry_allowed && errno == EAGAIN) {
                continue;
            }
            fprintf(stderr, "\n%s: error: Socket read error %d\n", appname, errno);
            return -1;
        }
        if ((size_t)r < sizeof(nbmsg)) {
            fprintf(stderr, "\n%s: error: Read too short\n", appname);
            return -1;
        }
#ifdef DEBUG
        fprintf(stdout, " < magic = %08x, cookie = %08x, cmd = %08x, arg = %08x\n",
                ack->magic, ack->cookie, ack->cmd, ack->arg);
#endif

        if (ack->magic != NB_MAGIC) {
            fprintf(stderr, "\n%s: error: Bad magic\n", appname);
            return 0;
        }
        if (msg) {
            if (ack->cookie > msg->cookie) {
                fprintf(stderr, "\n%s: error: Bad cookie\n", appname);
                return 0;
            }
        }

        if (ack->cmd == NB_ACK || ack->cmd == NB_FILE_RECEIVED) {
            return 0;
        }

        switch (ack->cmd) {
        case NB_ERROR:
            fprintf(stderr, "\n%s: error: Generic error\n", appname);
            break;
        case NB_ERROR_BAD_CMD:
            fprintf(stderr, "\n%s: error: Bad command\n", appname);
            break;
        case NB_ERROR_BAD_PARAM:
            fprintf(stderr, "\n%s: error: Bad parameter\n", appname);
            break;
        case NB_ERROR_TOO_LARGE:
            fprintf(stderr, "\n%s: error: File too large\n", appname);
            break;
        case NB_ERROR_BAD_FILE:
            fprintf(stderr, "\n%s: error: Bad file\n", appname);
            break;
        default:
            fprintf(stderr, "\n%s: error: Unknown command 0x%08X\n", appname, ack->cmd);
        }
        return -1;
    }
    fprintf(stderr, "\n%s: error: Unexpected code path\n", appname);
    return -1;
}

static int io_send(int s, nbmsg* msg, size_t len) {
    for (int i = 0; i < MAX_SEND_RETRIES; i++) {
#if defined(__APPLE__)
        bool retry_allowed = i + 1 < MAX_SEND_RETRIES;
#endif

        int r = write(s, msg, len);
        if (r < 0) {
#if defined(__APPLE__)
            if (retry_allowed && errno == ENOBUFS) {
                // On Darwin we manage to overflow the ethernet driver, so retry
                struct timespec reqtime;
                reqtime.tv_sec = 0;
                reqtime.tv_nsec = 50 * 1000;
                nanosleep(&reqtime, NULL);
                continue;
            }
#endif
            fprintf(stderr, "\n%s: error: Socket write error %d\n", appname, errno);
            return -1;
        }
        return 0;
    }
    fprintf(stderr, "\n%s: error: Unexpected code path\n", appname);
    return -1;
}

static int io(int s, nbmsg* msg, size_t len, nbmsg* ack, bool wait_reply) {
    int r, n;
    struct timeval tv;
    fd_set reads, writes;
    fd_set* ws = NULL;
    fd_set* rs = NULL;

    ack->cookie = 0;
    ack->cmd = 0;
    ack->arg = 0;

    FD_ZERO(&reads);
    if (!wait_reply) {
        FD_SET(s, &reads);
        rs = &reads;
    }

    FD_ZERO(&writes);
    if (msg && len > 0) {
        msg->magic = NB_MAGIC;
        msg->cookie = cookie++;

        FD_SET(s, &writes);
        ws = &writes;
    }

    if (rs || ws) {
        n = s + 1;
        tv.tv_sec = 10;
        tv.tv_usec = 500000;
        int rv = select(n, rs, ws, NULL, &tv);
        if (rv == -1) {
            fprintf(stderr, "\n%s: error: Select failed %d\n", appname, errno);
            return -1;
        } else if (rv == 0) {
            // Timed-out
            fprintf(stderr, "\n%s: error: Select timed out\n", appname);
            return -1;
        } else {
            r = 0;
            if (FD_ISSET(s, &reads)) {
                r = io_rcv(s, msg, ack);
            }

            // If we got an ack, don't bother sending anything - go handle the ack first
            if (!r && FD_ISSET(s, &writes) && (ack->cookie == 0 || ack->cmd != NB_ACK)) {
                r = io_send(s, msg, len);
            }

            if (r || !wait_reply) {
                return r;
            }
        }
    } else if (!wait_reply) { // no-op
        return 0;
    }

    if (wait_reply) {
        return io_rcv(s, msg, ack);
    }
    fprintf(stderr, "\n%s: error: Select triggered without events\n", appname);
    return -1;
}

typedef struct {
    FILE* fp;
    const char* data;
    size_t datalen;
    const char* ptr;
    size_t avail;
} xferdata;

static ssize_t xread(xferdata* xd, void* data, size_t len) {
    if (xd->fp == NULL) {
        if (len > xd->avail) {
            len = xd->avail;
        }
        memcpy(data, xd->ptr, len);
        xd->avail -= len;
        xd->ptr += len;
        return len;
    } else {
        ssize_t r = fread(data, 1, len, xd->fp);
        if (r == 0) {
            return ferror(xd->fp) ? -1 : 0;
        }
        return r;
    }
}

static int xseek(xferdata* xd, size_t off) {
    if (xd->fp == NULL) {
        if (off > xd->datalen) {
            return -1;
        }
        xd->ptr = xd->data + off;
        xd->avail = xd->datalen - off;
        return 0;
    } else {
        return fseek(xd->fp, off, SEEK_SET);
    }
}

// UDP6_MAX_PAYLOAD (ETH_MTU - ETH_HDR_LEN - IP6_HDR_LEN - UDP_HDR_LEN)
//      1452           1514   -     14      -     40      -    8
// nbfile is PAYLOAD_SIZE + 2 * sizeof(size_t)

// Some EFI network stacks have problems with larger packets
// 1280 is friendlier
#define PAYLOAD_SIZE 1280

int netboot_xfer(struct sockaddr_in6* addr, const char* fn, const char* name) {
    xferdata xd;
    char msgbuf[2048];
    char ackbuf[2048];
    char tmp[INET6_ADDRSTRLEN];
    struct timeval tv;
    nbmsg* msg = (void*)msgbuf;
    nbmsg* ack = (void*)ackbuf;
    int s;
    int status = -1;
    size_t current_pos = 0;
    size_t sz = 0;

    if (!strcmp(fn, "(cmdline)")) {
        xd.fp = NULL;
        xd.data = name;
        xd.datalen = strlen(name) + 1;
        xd.ptr = xd.data;
        xd.avail = xd.datalen;
        name = NB_CMDLINE_FILENAME;
        sz = xd.datalen;
    } else {
        if ((xd.fp = fopen(fn, "rb")) == NULL) {
            fprintf(stderr, "%s: error: Could not open file %s\n", appname, fn);
            return -1;
        }
        if (fseek(xd.fp, 0L, SEEK_END)) {
            fprintf(stderr, "%s: error: Could not determine size of %s\n", appname, fn);
        } else if ((sz = ftell(xd.fp)) < 0) {
            fprintf(stderr, "%s: error: Could not determine size of %s\n", appname, fn);
            sz = 0;
        } else if (fseek(xd.fp, 0L, SEEK_SET)) {
            fprintf(stderr, "%s: error: Failed to rewind %s\n", appname, fn);
            return -1;
        }
    }

    if (sz > 0) {
        initialize_status(xd.data, sz);
    }

    if ((s = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)) < 0) {
        fprintf(stderr, "%s: error: Cannot create socket %d\n", appname, errno);
        goto done;
    }
    tv.tv_sec = 0;
    tv.tv_usec = 250 * 1000;
    setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
    if (connect(s, (void*)addr, sizeof(*addr)) < 0) {
        fprintf(stderr, "%s: error: Cannot connect to [%s]%d\n", appname,
                inet_ntop(AF_INET6, &addr->sin6_addr, tmp, sizeof(tmp)),
                ntohs(addr->sin6_port));
        goto done;
    }

    msg->cmd = NB_SEND_FILE;
    msg->arg = sz;
    strcpy((void*)msg->data, name);
    if (io(s, msg, sizeof(nbmsg) + strlen(name) + 1, ack, true)) {
        fprintf(stderr, "%s: error: Failed to start transfer\n", appname);
        goto done;
    }

    msg->cmd = NB_DATA;
    msg->arg = 0;

    bool completed = false;
    do {
        struct timeval packet_start_time;
        gettimeofday(&packet_start_time, NULL);

        ssize_t r = xread(&xd, msg->data, PAYLOAD_SIZE);
        if (r < 0) {
            fprintf(stderr, "\n%s: error: Reading '%s'\n", appname, fn);
            goto done;
        }

        update_status(msg->arg);

        if (r == 0) {
            fprintf(stderr, "\n%s: Reached end of file, waiting for confirmation.\n", appname);
            // Do not send anything, but keep waiting on incoming messages
            if (io(s, NULL, 0, ack, true)) {
                goto done;
            }
        } else {
            if (current_pos + (size_t)r >= sz) {
                msg->cmd = NB_LAST_DATA;
            } else {
                msg->cmd = NB_DATA;
            }

            if (io(s, msg, sizeof(nbmsg) + r, ack, false)) {
                goto done;
            }

            // Some UEFI netstacks can lose back-to-back packets at max speed
            // so throttle output.
            // At 1280 bytes per packet, we should at least have 10 microseconds
            // between packets, to be safe using 20 microseconds here.
            // 1280 bytes * (1,000,000/10) seconds = 128,000,000 bytes/seconds = 122MB/s = 976Mb/s
            // We wait as a busy wait as the context switching a sleep can cause
            // will often degrade performance significantly.
            int64_t us_since_last_packet;
            do {
                struct timeval now;
                gettimeofday(&now, NULL);
                us_since_last_packet = (int64_t)(now.tv_sec - packet_start_time.tv_sec) * 1000000 +
                                       ((int64_t)now.tv_usec - (int64_t)packet_start_time.tv_usec);
            } while (us_since_last_packet < us_between_packets);
        }

        // ACKs really are NACKs
        if (ack->cookie > 0 && ack->cmd == NB_ACK) {
            // ACKs tend to be generated in groups, since a dropped packet will cause ACKs for all
            // outstanding packets. Therefore briefly sleep when we receive an ACK with a different
            // position, to let things settle and prevent ourselves from fighting subsequent acks.
            if (ack->arg != current_pos) {
                fprintf(stderr, "\n%s: need to reset to %d from %zu\n",
                        appname, ack->arg, current_pos);
                current_pos = ack->arg;

                tv.tv_usec = 100000;
                select(0, NULL, NULL, NULL, &tv);
            }
            if (xseek(&xd, current_pos)) {
                fprintf(stderr, "\n%s: error: Failed to rewind '%s' to %zu\n",
                        appname, fn, current_pos);
                goto done;
            }
        } else if (ack->cmd == NB_FILE_RECEIVED) {
            current_pos += r;
            completed = true;
        } else {
            current_pos += r;
        }

        msg->arg = current_pos;
    } while (!completed);

    status = 0;
    update_status(msg->arg);
done:
    if (s >= 0) {
        close(s);
    }
    if (xd.fp != NULL) {
        fclose(xd.fp);
    }
    return status;
}
